{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5b0fd9e2-e4ee-4252-9738-c75833c034a0",
   "metadata": {},
   "source": [
    "## 将memory插入到提示词模板中"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa678d24-3fbf-42cc-ac77-6ed7d8407fcb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.agents import Tool\n",
    "from langchain.agents import AgentType\n",
    "from langchain.memory import ConversationBufferMemory\n",
    "from langchain.chat_models import ChatOpenAI\n",
    "from langchain.utilities import SerpAPIWrapper\n",
    "from langchain.agents import initialize_agent\n",
    "from langchain.chains import LLMMathChain\n",
    "from langchain.prompts import  MessagesPlaceholder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84af4a21-9deb-4279-8027-1f16f4e57328",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Chat model used by the rest of the notebook (temperature pinned to 0).\n",
    "llm = ChatOpenAI(model=\"gpt-4-1106-preview\", temperature=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7644db91-4c71-41ae-9397-6c3d945e9ab2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the tools the agent is allowed to call.\n",
    "import os\n",
    "from getpass import getpass\n",
    "\n",
    "# Never commit API keys to the notebook. SERPAPI_API_KEY is taken from the\n",
    "# environment; if it is missing, ask for it interactively without echoing it.\n",
    "if not os.environ.get(\"SERPAPI_API_KEY\"):\n",
    "    os.environ[\"SERPAPI_API_KEY\"] = getpass(\"SERPAPI_API_KEY: \")\n",
    "\n",
    "# Web search tool; the wrapper is built without arguments and picks the key\n",
    "# up from the SERPAPI_API_KEY environment variable set above.\n",
    "search = SerpAPIWrapper()\n",
    "\n",
    "# Math tool driven by the same chat model. from_llm is the supported\n",
    "# constructor; passing llm= directly to LLMMathChain(...) is deprecated.\n",
    "llm_math_chain = LLMMathChain.from_llm(\n",
    "    llm=llm,\n",
    "    verbose=True\n",
    ")\n",
    "\n",
    "# The description strings are what the model reads when choosing a tool.\n",
    "tools = [\n",
    "    Tool(\n",
    "        name=\"Search\",\n",
    "        func=search.run,\n",
    "        description=\"useful for when you need to answer questions about current events or the current state of the world\"\n",
    "    ),\n",
    "    Tool(\n",
    "        name=\"Calculator\",\n",
    "        func=llm_math_chain.run,\n",
    "        description=\"useful for when you need to answer questions about math\"\n",
    "    ),\n",
    "]\n",
    "print(tools)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11c70b30-26cb-4549-8e68-407622d663de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Conversation memory component: history is kept under the chat_history key\n",
    "# and handed back as message objects (return_messages=True).\n",
    "memory = ConversationBufferMemory(return_messages=True, memory_key=\"chat_history\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29f846c3-9677-491d-ba78-fd6473e4e488",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the agent with the default prompt: memory is attached, but no\n",
    "# agent_kwargs are passed, so the prompt itself is left untouched.\n",
    "agent_chain = initialize_agent(\n",
    "    tools=tools,\n",
    "    llm=llm,\n",
    "    agent=AgentType.OPENAI_FUNCTIONS,\n",
    "    memory=memory,  # conversation memory component\n",
    "    handle_parsing_errors=True,  # handle output-parsing errors\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "856a5754-325b-419b-99b7-0f90e57b9b52",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the agent's default prompt: first the whole message list,\n",
    "# then the first three entries one at a time.\n",
    "prompt_messages = agent_chain.agent.prompt.messages\n",
    "print(prompt_messages)\n",
    "for position in range(3):\n",
    "    print(prompt_messages[position])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffe1aaf5-9595-4d0e-aa7b-11c922a89ed7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the agent so its prompt has a slot for the conversation history.\n",
    "# extra_prompt_messages (passed through agent_kwargs) adds a placeholder whose\n",
    "# variable_name must match the memory_key of the memory object (chat_history).\n",
    "#\n",
    "# The agent_scratchpad placeholder is intentionally NOT listed here: the\n",
    "# OpenAI-functions prompt builder already appends one after the human input,\n",
    "# so listing it again would render the tool-call trace twice in every request.\n",
    "agent_chain = initialize_agent(\n",
    "    tools,\n",
    "    llm,\n",
    "    agent=AgentType.OPENAI_FUNCTIONS,\n",
    "    verbose=True,\n",
    "    handle_parsing_errors=True,  # handle output-parsing errors\n",
    "    agent_kwargs={\n",
    "        \"extra_prompt_messages\": [\n",
    "            MessagesPlaceholder(variable_name=\"chat_history\"),  # history slot\n",
    "        ],\n",
    "    },\n",
    "    memory=memory,  # conversation memory component\n",
    ")\n",
    "\n",
    "# Print the full prompt and then every message, so the chat_history slot and\n",
    "# the single agent_scratchpad slot can both be checked.\n",
    "print(agent_chain.agent.prompt.messages)\n",
    "for message in agent_chain.agent.prompt.messages:\n",
    "    print(message)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (langchain)",
   "language": "python",
   "name": "langchain-env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
